Super resolution on medical images:

Set-up:

Imports:
In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import pandas as pd
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
import tensorflow as tf
from torchvision.models import vgg16_bn
import torchvision.transforms as transforms
from skimage import measure

from modified_model_fx import *
In [2]:
import warnings
warnings.filterwarnings('ignore')
In [3]:
trans1 = transforms.ToTensor()
trans = transforms.ToPILImage()

Set CUDA:

In [4]:
torch.cuda.set_device(3)

Import data:

Input images:

  • HR images: (3, 502, 672), from originals of size (3, 1004, 1344) cut into 4 pieces
  • LR images: (3, 250, 334), from originals of size (3, 500, 669) cut into 4 pieces
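The patching itself happens outside this notebook; as a minimal sketch of how a (3, 1004, 1344) original could yield the four (3, 502, 672) HR patches, assuming NumPy arrays in (C, H, W) layout and a hypothetical cut_in_four helper:

```python
import numpy as np

def cut_in_four(img):
    """Split a (C, H, W) array into its four quadrants
    (top-left, top-right, bottom-left, bottom-right)."""
    _, h, w = img.shape
    hh, hw = h // 2, w // 2
    return [img[:, :hh, :hw], img[:, :hh, hw:],
            img[:, hh:, :hw], img[:, hh:, hw:]]

full = np.zeros((3, 1004, 1344))   # stand-in for an original HR image
patches = cut_in_four(full)        # four patches of shape (3, 502, 672)
print([p.shape for p in patches])
```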

Paths to data:

In [5]:
path = Path('../../../../../SCRATCH2/marvande/data/train/HR/')

# path to HR data:
path_hr = path / 'HR_patches_train/tiff_files'

# path where we create the LR data:
path_lr = path / 'small-250/train'

# path to MR data (of same size as HR):
path_mr = path / 'small-502/train'

# path to original LR data:
path_lr_or = path / 'HR_patches_resized/jpg_images'

assert path.exists(), f"need dataset @ {path}"
assert path_hr.exists()

Have a look at what type of data is in those folders. First, we look at our HR images from path_hr.

In [6]:
il = ImageList.from_folder(path_hr)
PIL.Image.open(
    '../../../../../SCRATCH2/marvande/data/train/HR/HR_patches_train/tiff_files/0124_[47360,11368]_part_1_1_.tif'
)
Out[6]:

Create LR and MR data:

From this HR data, we create LR versions in path_lr and path_mr for training. LR images are of size (3, 250, 334), while MR images are the same size as the HR images, (3, 502, 672), but of lower quality (MR images are used in the second phase of training).

In [7]:
def resize_one(fn, i, path, size):
    """resize_one: resizes images to input size and saves them in path, 
    quality is lowered as to get LR images. 
    
    """
    dest = path / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    img = img.resize(size, resample=PIL.Image.BICUBIC).convert('RGB')
    img.save(dest, quality=60)
In [8]:
# create smaller image sets of lower quality the first time this nb is run:
sets = [(path_lr, (334, 250)), (path_mr, (672, 502))]
for p, size in sets:
    if not p.exists():
        print(f"resizing to {size} into {p}")
        parallel(partial(resize_one, path=p, size=size), il.items)

Have a look at the new LR and MR images and their sizes.

In [9]:
print('LR:')
print(ImageList.from_folder(path_lr))
print('MR:')
print(ImageList.from_folder(path_mr))
LR:
ImageList (2056 items)
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train
MR:
ImageList (2056 items)
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-502/train

Training and validation data:

As the model's first training phase runs at the LR size, we create the first training set at size (3, 250, 334): HR targets are downsized to (3, 250, 334), and both HR and LR images are augmented with several transformations (see modified_model_fx.py). These data augmentation techniques help avoid overfitting during training.
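The actual transformations live in modified_model_fx.py and are not shown here; to illustrate the paired-augmentation idea behind tfm_y=True (the same random transform applied to input and target), here is a NumPy-only sketch with a hypothetical paired_hflip helper:

```python
import numpy as np

rng = np.random.default_rng(42)

def paired_hflip(x, y, p=0.5):
    """Flip input and target horizontally with the SAME random draw,
    mirroring what tfm_y=True does for image-to-image data."""
    if rng.random() < p:
        return x[:, :, ::-1], y[:, :, ::-1]
    return x, y

x = np.arange(12, dtype=float).reshape(1, 3, 4)  # fake (C, H, W) LR image
y = x * 2                                        # fake HR target
x2, y2 = paired_hflip(x, y, p=1.0)               # force the flip
# both arrays were flipped together, so un-flipping recovers the originals
print((x2[:, :, ::-1] == x).all(), (y2[:, :, ::-1] == y).all())
```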

In [10]:
# set image size and batch size to which data is transformed:
bs, size = 15, (250, 334)
arch = models.resnet34

src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)

Note that normalization here uses the data's own statistics rather than imagenet_stats (that variant is left commented out in the function below).

In [11]:
def get_data(bs, size):
    """
    get_data: creates training and validation data from LR and HR. 
    downsizes HR to LR size and applies transformations to both. 
    """

    #label_from_func: apply func to every input to get its label.
    # defining a custom function to extract the labels
    data = src.label_from_func(lambda x: path_hr / x.relative_to(path_lr))

    #apply data transformations,
    #data = data.transform(tfms, size=size, tfm_y=True).databunch(bs=bs).normalize(imagenet_stats,do_y=True)
    data = data.transform(
        tfms, size=size, tfm_y=True).databunch(bs=bs).normalize(do_y=True)

    data.c = 3
    return data
In [12]:
data = get_data(bs, size)
data
Out[12]:
ImageDataBunch;

Train: LabelList (1851 items)
x: ImageImageList
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
y: ImageList
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Valid: LabelList (205 items)
x: ImageImageList
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
y: ImageList
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Test: None

Have a look at a few samples from the validation set. On the left are the LR images and on the right the corresponding HR images.

In [13]:
# lr 
x = data.one_batch(ds_type=DatasetType.Valid)[0][4]
# hr
y = data.one_batch(ds_type=DatasetType.Valid)[1][4]
print('LR data shape {} and HR shape {}'.format(list(x.shape), list(y.shape)))

MSE = mse_mult_chann(x.numpy(), y.numpy())
NMSE = measure.compare_nrmse(x.numpy(), y.numpy())
SSIM = ssim_mult_chann(x.numpy(), y.numpy())
print('MSE: {}, NMSE: {}, SSIM: {}'.format(MSE, NMSE, SSIM))

plot_single_image(x, '', (10,10))
plot_single_image(y, '', (10,10))
LR data shape [3, 250, 334] and HR shape [3, 250, 334]
MSE: 0.00028784769793362835, NMSE: 0.01852989861384659, SSIM: 0.9801451255576129
In [14]:
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(20, 20))

Feature loss:

Create loss metrics used in the model.

In [15]:
def gram_matrix(x):
    """
    The Gram matrix of a set of vectors in an inner product
    space is the Hermitian matrix of their pairwise inner products.
    """
    n,c,h,w = x.size()
    x = x.view(n, c, -1)
    return (x @ x.transpose(1,2))/(c*h*w)

Select the data of an image and use it to compute its Gram matrix.

In [16]:
im = data.valid_ds[0][1]
t = im.data
t = torch.stack([t,t])
In [17]:
gram_matrix(t)
Out[17]:
tensor([[[0.2957, 0.3039, 0.2875],
         [0.3039, 0.3125, 0.2959],
         [0.2875, 0.2959, 0.2840]],

        [[0.2957, 0.3039, 0.2875],
         [0.3039, 0.3125, 0.2959],
         [0.2875, 0.2959, 0.2840]]])

We define a base loss as the L1 loss (F.l1_loss).

In [18]:
base_loss = F.l1_loss

Construct a pre-trained VGG-16 model (from "Very Deep Convolutional Networks for Large-Scale Image Recognition"). VGG is a Convolutional Neural Network (CNN) architecture devised in 2014; the 16-layer version is used in the loss function for training this model. The VGG network, pretrained on ImageNet, is used to evaluate the generator model's loss.

Further, we set requires_grad to False: the VGG network is used only as a fixed feature extractor, so we freeze it and never need gradients w.r.t. its parameters.

The head of the VGG model consists of the final fully connected and softmax layers. This head is ignored; the loss function uses the intermediate activations in the backbone of the network, which represent the feature detections.

Those activations can be found by looking through the VGG model for the max-pooling layers, which is where the grid size changes and features are detected. We select the activation layer immediately before each pooling layer.

In [19]:
vgg_m = vgg16_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)

Select the layer indices immediately before each MaxPool2d block:

In [20]:
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]
Out[20]:
([5, 12, 22, 32, 42],
 [ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True)])

Create the feature loss from the model and layer ids selected above.

Source:

Main points:

  • Loss functions using these techniques can be used during the training of U-Net based model architectures
  • Feature loss: the loss function used is similar to the one in the paper, using VGG-16 but combined with a pixel mean squared error loss and a Gram matrix style loss.
  • The training of a model can use this loss function based on the VGG model's activations. The loss function remains fixed throughout training, unlike the critic part of a GAN.
  • Feature loss: the feature map is 256 channels by 28 by 28. The activations at the same layer for the target (original) image and the generated image are compared using mean squared error or least absolute error (L1) as the base loss; this implementation uses L1. These are the feature losses. They let the loss function know what features are in the ground-truth image and evaluate how well the prediction's features match them, rather than only comparing pixel differences, so the model trained with this loss produces much finer detail in its generated output.
  • Gram matrix style loss: a Gram matrix defines a style with respect to specific content. Calculating the Gram matrix of each feature activation in the target/ground-truth image defines the style of that feature; computing the same Gram matrix from the activations of the prediction lets the two be compared, measuring how close the style of the predicted feature is to the target. A Gram matrix is the matrix product of the activations with the activation matrix's transpose. This lets the model learn to generate predictions whose features look correct in style and in context, with the end result appearing closer to, or the same as, the target/ground truth.

Predictions from models trained with this loss function: the generated predictions have both convincing fine detail and style. Style and fine detail cover different aspects of image quality, from predicting fine pixel detail to predicting correct colours.

In [21]:
class FeatureLoss(nn.Module):
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        #feat losses
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        #gram: 
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()
In [22]:
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

Train

In [23]:
lr = 1e-3
def do_fit(save_name, lrs=slice(lr), pct_start=0.9):
    learn.fit_one_cycle(10, lrs, pct_start=pct_start)
    learn.save(save_name, return_path=True)
    learn.show_results(rows=1, imgsize=10)
In [24]:
# delete all tensors and free cache:
for obj in gc.get_objects():
    if torch.is_tensor(obj):
        del obj
torch.cuda.empty_cache()
gc.collect()
#learn.destroy()

#get free memory (in MBs) for the currently selected gpu id, after emptying the cache
print(
    'free memory (in MBs) for the currently selected gpu id, after emptying the cache: ',
    gpu_mem_get_free_no_cache())

print(
    'used memory (in MBs) for the currently selected gpu id, after emptying the cache:',
    gpu_mem_get_used_no_cache())

gpu_mem_get_all()
free memory (in MBs) for the currently selected gpu id, after emptying the cache:  10933
used memory (in MBs) for the currently selected gpu id, after emptying the cache: 1279
Out[24]:
[GPUMemory(total=12212, free=11525, used=687),
 GPUMemory(total=12196, free=12185, used=10),
 GPUMemory(total=12196, free=12185, used=10),
 GPUMemory(total=12212, free=10933, used=1279)]
In [53]:
wd = 1e-3
learn = unet_learner(data,
                     arch,
                     wd=wd,
                     loss_func=feat_loss,
                     callback_fns=LossMetrics,
                     blur=True,
                     norm_type=NormType.Weight)
# garbage collection:
gc.collect()
Out[53]:
0
In [54]:
learn.lr_find()
learn.recorder.plot()
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time

70.73% [87/123 01:51<00:46 6.5438]
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

Phase 1:

In [55]:
lr = 1e-3
In [56]:
do_fit('1a_mod_norm', slice(lr * 10))
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 1.008296 0.690632 0.161900 0.101998 0.095766 0.026347 0.154818 0.131999 0.017804 02:39
1 0.590003 0.569422 0.136047 0.081931 0.077113 0.021939 0.128694 0.108189 0.015507 02:38
2 0.500206 0.548980 0.124974 0.079605 0.074822 0.021268 0.129823 0.103643 0.014845 02:38
3 0.520237 0.571204 0.136269 0.079828 0.074964 0.021345 0.136374 0.107280 0.015143 02:38
4 0.519486 0.578849 0.128569 0.080207 0.076154 0.021809 0.144221 0.112383 0.015507 02:38
5 0.553103 0.624747 0.128062 0.082028 0.079325 0.023448 0.165605 0.128812 0.017467 02:37
6 0.505037 0.568536 0.130136 0.079478 0.074588 0.021436 0.141726 0.105886 0.015286 02:37
7 0.506419 0.593972 0.133298 0.079542 0.076337 0.022220 0.149033 0.117498 0.016043 02:37
8 0.490363 0.537218 0.119453 0.078438 0.073763 0.020848 0.127719 0.102670 0.014326 02:37
9 0.449800 0.501512 0.112042 0.074938 0.069593 0.019674 0.116367 0.095105 0.013793 02:37
In [62]:
! cp ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train/models/1a_mod_norm.pth ../../../../../SCRATCH2/marvande/data/train/HR/models/
In [57]:
learn.unfreeze()
In [58]:
learn.load('1a_mod_norm');
In [59]:
do_fit('1b_mod_norm', slice(1e-5,lr))
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.434788 0.505183 0.112460 0.075113 0.069804 0.019747 0.118277 0.095920 0.013861 02:42
1 0.437264 0.491137 0.111246 0.074209 0.068692 0.019351 0.111386 0.092770 0.013483 02:43
2 0.438313 0.508856 0.112532 0.075450 0.070217 0.019881 0.120256 0.096566 0.013955 02:42
3 0.438041 0.513787 0.113332 0.075646 0.070520 0.020018 0.122361 0.097787 0.014123 02:43
4 0.432099 0.495587 0.111142 0.074471 0.069071 0.019492 0.113758 0.093999 0.013654 02:42
5 0.437167 0.495151 0.111800 0.074471 0.069048 0.019450 0.113194 0.093635 0.013553 02:43
6 0.433428 0.504120 0.111805 0.074743 0.069580 0.019716 0.117845 0.096523 0.013908 02:43
7 0.433134 0.520898 0.113892 0.076077 0.071057 0.020242 0.126031 0.099271 0.014327 02:43
8 0.429807 0.510807 0.112086 0.075466 0.070370 0.019988 0.121533 0.097290 0.014075 02:43
9 0.425064 0.509859 0.112178 0.075397 0.070223 0.019923 0.121188 0.096931 0.014020 02:43
In [63]:
! cp ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train/models/1b_mod_norm.pth ../../../../../SCRATCH2/marvande/data/train/HR/models/

Phase 2:

In [25]:
torch.cuda.empty_cache()
gc.collect()
#learn.destroy()

#get free memory (in MBs) for the currently selected gpu id, after emptying the cache
print(
    'free memory (in MBs) for the currently selected gpu id, after emptying the cache: ',
    gpu_mem_get_free_no_cache())

print(
    'used memory (in MBs) for the currently selected gpu id, after emptying the cache:',
    gpu_mem_get_used_no_cache())

gpu_mem_get_all()
free memory (in MBs) for the currently selected gpu id, after emptying the cache:  10933
used memory (in MBs) for the currently selected gpu id, after emptying the cache: 1279
Out[25]:
[GPUMemory(total=12212, free=11525, used=687),
 GPUMemory(total=12196, free=12185, used=10),
 GPUMemory(total=12196, free=12185, used=10),
 GPUMemory(total=12212, free=10933, used=1279)]
In [26]:
new_size = (502, 672)
bs = 4
data = get_data(bs, new_size)

wd = 1e-3
learn = unet_learner(
    data,
    arch,
    wd=wd,
    loss_func=feat_loss,
    callback_fns=LossMetrics,
    blur=True,
    norm_type=NormType.Weight,)

lr = 1e-3
# garbage collection:
gc.collect()
Out[26]:
0
In [27]:
data
Out[27]:
ImageDataBunch;

Train: LabelList (1851 items)
x: ImageImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
y: ImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Valid: LabelList (205 items)
x: ImageImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
y: ImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Test: None
In [28]:
new_size = (502, 672)
bs = 4
data = get_data(bs, new_size)
# lr 
x = data.one_batch(ds_type=DatasetType.Valid)[0][1]
# hr
y = data.one_batch(ds_type=DatasetType.Valid)[1][1]
print('LR data shape {} and HR shape {}'.format(list(x.shape), list(y.shape)))

MSE = mse_mult_chann(x.numpy(), y.numpy())
NMSE = measure.compare_nrmse(x.numpy(), y.numpy())
SSIM = ssim_mult_chann(x.numpy(), y.numpy())
print('MSE: {}, NMSE: {}, SSIM: {}'.format(MSE, NMSE, SSIM))

plot_single_image(x, '', (10,10))
plot_single_image(y, '', (10,10))
LR data shape [3, 502, 672] and HR shape [3, 502, 672]
MSE: 0.0005481507648623173, NMSE: 0.026690426222026973, SSIM: 0.937297089535451
In [29]:
learn.data = data
learn.freeze()
gc.collect()
Out[29]:
4799
In [30]:
learn.load('1b_mod_norm');
In [31]:
do_fit('2a_mod_norm')
80.00% [8/10 1:42:14<25:33]
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.740985 0.920265 0.171159 0.125428 0.122721 0.040708 0.277791 0.155325 0.027134 12:56
1 0.696377 0.876691 0.170158 0.123192 0.119441 0.039550 0.251596 0.146586 0.026167 12:45
2 0.680609 0.822760 0.170623 0.119897 0.114547 0.037819 0.226394 0.129170 0.024311 12:45
3 0.667253 0.847775 0.170835 0.121531 0.116248 0.038580 0.237946 0.137996 0.024639 12:46
4 0.650902 0.819220 0.171181 0.119982 0.113914 0.037488 0.224214 0.128585 0.023857 12:45
5 0.667102 0.799208 0.174234 0.119610 0.112386 0.037067 0.211490 0.121501 0.022919 12:45
6 0.644349 0.811574 0.172244 0.120267 0.113188 0.037376 0.219929 0.124993 0.023577 12:45
7 0.639402 0.847978 0.172050 0.122212 0.115634 0.038252 0.239957 0.135474 0.024399 12:45

40.48% [187/462 04:56<07:15 0.6604]
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

In [37]:
! cp ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train/models/2a_mod_norm.pth ../../../../../SCRATCH2/marvande/data/train/HR/models/
In [33]:
learn.load('2a_mod_norm');
In [34]:
do_fit('2b_mod_norm', slice(1e-6,1e-4), pct_start=0.3)
50.00% [5/10 1:03:46<1:03:46]
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.619233 0.787176 0.170035 0.119285 0.111301 0.036564 0.209400 0.117973 0.022618 12:45
1 0.620023 0.793075 0.169971 0.119717 0.111755 0.036718 0.212607 0.119577 0.022730 12:45
2 0.630182 0.786317 0.169930 0.119382 0.111219 0.036565 0.209282 0.117404 0.022535 12:45
3 0.621396 0.787919 0.169926 0.119398 0.111393 0.036638 0.209155 0.118806 0.022603 12:45
4 0.614185 0.792894 0.169873 0.119537 0.111600 0.036684 0.212626 0.119905 0.022670 12:45

90.48% [418/462 11:01<01:09 0.6182]
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

In [36]:
! cp ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train/models/2b_mod_norm.pth ../../../../../SCRATCH2/marvande/data/train/HR/models/

Testing:

Reconstruct the learn object, this time with l1_loss instead of feature_loss: the loss function is irrelevant for prediction, and a plain L1 loss avoids keeping the VGG feature network around at inference time.

Possible additional metrics: PSNR and the Fréchet Inception Distance (FID) score.
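PSNR, at least, is straightforward to derive from the MSE; a minimal NumPy sketch (this psnr helper is an illustration, not part of modified_model_fx.py):

```python
import numpy as np

def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, data_range]."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(data_range ** 2 / mse)

a = np.zeros((3, 4, 4))
b = np.full((3, 4, 4), 0.1)   # constant error of 0.1 -> MSE = 0.01
print(psnr(a, b))             # ~20 dB
```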

In [38]:
learn = None
gc.collect()
learn = unet_learner(data,
                     arch,
                     loss_func=F.l1_loss,
                     blur=True,
                     norm_type=NormType.Weight)

Prediction:

Set the sizes of the LR test inputs (size_lr) and of the HR-sized images (size_mr).

In [39]:
# path to test images:
path_test = path/'HR_patches_test/cut_images/jpeg_images/'

# path to save LR and MR test images:
path_lr_test = path / 'small-250/test'
path_mr_test = path / 'small-502/test'

#Check free GPU RAM:
free = gpu_mem_get_free_no_cache()
print(f"using size={size}, have {free} MB of GPU RAM free")
using size=(250, 334), have 7602 MB of GPU RAM free
In [40]:
! rm -r '../../../../../SCRATCH2/marvande/data/train/HR/small-250/test'
! rm -r '../../../../../SCRATCH2/marvande/data/train/HR/small-502/test'
rm: cannot remove '../../../../../SCRATCH2/marvande/data/train/HR/small-250/test': No such file or directory
rm: cannot remove '../../../../../SCRATCH2/marvande/data/train/HR/small-502/test': No such file or directory
In [41]:
def resize_test(fn, i, path, size):
    """resize_one: resizes images to input size and saves them in path, 
    quality is lowered as to get LR images. 
    
    """
    dest = path / fn.relative_to(path_test)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    img = img.resize(size, resample=PIL.Image.BICUBIC).convert('RGB')
    img.save(dest, quality=60)
In [42]:
il = ImageList.from_folder(path_test)
In [43]:
# create smaller image sets the first time this nb is run:
sets = [(path_lr_test, (334, 250)), (path_mr_test, (672, 502))]
for p, size in sets:
    if not p.exists():
        print(f"resizing to {size} into {p}")
        parallel(partial(resize_test, path=p, size=size), il.items)
resizing to (334, 250) into ../../../../../SCRATCH2/marvande/data/train/HR/small-250/test
resizing to (672, 502) into ../../../../../SCRATCH2/marvande/data/train/HR/small-502/test

Create the test data: data_mr at the HR size, data_lr at the LR size, and data_hr from the original test images:

In [44]:
size_mr = (3, 502, 672)
size_lr = (3, 250, 334)

data_mr = (ImageImageList.from_folder(path_mr_test).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_test/x.name)
          .transform(tfms, size=size_mr, tfm_y=True)
          .databunch(bs=1).normalize(do_y=True))
data_mr.c = 3

data_lr = (ImageImageList.from_folder(path_lr_test).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_test/x.name)
          .transform(tfms, size=size_lr, tfm_y=True)
          .databunch(bs=1).normalize(do_y=True))
data_lr.c = 3


data_hr = (ImageImageList.from_folder(path_test).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_test/x.name)
          .transform(tfms, size=size_mr, tfm_y=True)
          .databunch(bs=1).normalize(do_y=True))
data_hr.c = 3

data_lr, data_mr, data_hr 
Out[44]:
(ImageDataBunch;
 
 Train: LabelList (908 items)
 x: ImageImageList
 Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
 y: ImageList
 Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/test;
 
 Valid: LabelList (100 items)
 x: ImageImageList
 Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
 y: ImageList
 Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/test;
 
 Test: None,
 ImageDataBunch;
 
 Train: LabelList (908 items)
 x: ImageImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 y: ImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-502/test;
 
 Valid: LabelList (100 items)
 x: ImageImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 y: ImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-502/test;
 
 Test: None,
 ImageDataBunch;
 
 Train: LabelList (908 items)
 x: ImageImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 y: ImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/HR_patches_test/cut_images/jpeg_images;
 
 Valid: LabelList (100 items)
 x: ImageImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 y: ImageList
 Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
 Path: ../../../../../SCRATCH2/marvande/data/train/HR/HR_patches_test/cut_images/jpeg_images;
 
 Test: None)

Load the learn object from the last phase (2b_mod_norm) and set its data to data_mr, so that the learner outputs images at the same size as the HR ground truth.

In [64]:
learn.load('2b_mod_norm')
# Here have to put mr if we want output the same size as HR:
learn.data = data_mr

Select the images we are going to use for testing from data_mr and data_lr:

In [65]:
# Ground truth HR image:
mr = data_mr.valid_ds.x.items[1]
im_mr = open_image(mr)

hr = data_mr.valid_ds.y.items[1]
#hr = data_hr.valid_ds.x.items[1]
im_HR_gt = open_image(hr)

# LR version of the same image:
lr = data_lr.valid_ds.x.items[1]
im_lr = open_image(lr)

# Check it's the same image:
mr, lr, hr
Out[65]:
(PosixPath('../../../../../SCRATCH2/marvande/data/train/HR/small-502/test/0131_[46740,17995]_part_4_-1.jpg'),
 PosixPath('../../../../../SCRATCH2/marvande/data/train/HR/small-250/test/0131_[46740,17995]_part_4_-1.jpg'),
 PosixPath('../../../../../SCRATCH2/marvande/data/train/HR/HR_patches_test/cut_images/jpeg_images/0131_[46740,17995]_part_4_-1.jpg'))

To be able to apply the model and predict the HR image from the LR, we need to resample the LR image to be of the same size as the HR, as the model inputs and outputs images of the same size.

In [62]:
# resample to same size as HR:
# for pytorch, have to add a first new dimension to indicate bs = 1
# now lr_data of shape [1, 3, 250, 334]
im_lr = im_lr.data.unsqueeze(0)
lr_resized = torch.nn.functional.interpolate(im_lr, size_mr[1:],
                                             mode='bicubic')
# remove the previous added dimension
lr_resized = lr_resized.squeeze()

input_image = lr_resized

# plot resized LR: 
plot_single_image(lr_resized, 
                  'Resized LR from size {} to size {}'\
                  .format(list(im_lr.squeeze().shape), 
                          list(lr_resized.shape)), (10,10))

Now that we have resized the LR to the same size as the ground truth HR, we can feed it to the model and predict a new HR image.

In [69]:
#print('HR ground thruth shape: {}'.format(list(im_HR_gt.shape)))

# Prediction of model: 
p, img_pred, b = learn.predict(Image(input_image))
print('Predicted HR shape: {}'.format(list(p.shape)))

# Assert reconstructed HR has same shape as ground truth HR:
assert(list(p.shape) == list(im_HR_gt.shape))
Predicted HR shape: [3, 502, 672]

Visualisation:

Compare this predicted image to the ground truth HR and LR:

In [70]:
plot_3_images(
    trans1(trans(im_mr.data)).numpy(),
    trans1(trans(im_HR_gt.data)).numpy(), img_pred, (20, 20))
In [71]:
print('PRed')
print(np.max(img_pred.numpy()))
print(np.min(img_pred.numpy()))

print('Ground truth HR')
print(np.max(trans1(trans(im_HR_gt.data)).numpy()))
print(np.min(trans1(trans(im_HR_gt.data)).numpy()))

print('LR bic')
print(np.max(input_image.numpy()))
print(np.min(input_image.numpy()))
PRed
1.0255638
-0.044366777
Ground truth HR
1.0
0.1254902
LR bic
1.0381817
0.006942164

Clip the prediction to the ground-truth HR range [min_HR, max_HR] and check the effect on the MSE:

In [72]:
max_HR = np.max(trans1(trans(im_HR_gt.data)).numpy())
min_HR = np.min(trans1(trans(im_HR_gt.data)).numpy())

test_im = np.clip(a_max = max_HR,a_min = min_HR, a = img_pred.numpy())

print(mse_mult_chann(img_pred.numpy(), trans1(trans(im_HR_gt.data)).numpy()))
mse_mult_chann(test_im, trans1(trans(im_HR_gt.data)).numpy())
0.0007655642339835681
Out[72]:
0.0007634917329585635
Evaluate loss and error metrics:
In [73]:
compare_images_metrics(img_pred.numpy(),
                       trans1(trans(im_HR_gt.data)).numpy(),
                       lr_resized.numpy(), '')
MSE: 0.00076556, NMSE: 0.02974234, SSIM : 0.9015

In [74]:
## Plot for several: 
LR_list = ImageList.from_folder(path_lr_test).items
HR_list = ImageList.from_folder(path_test).items
phenotypes = {'0': 'CD4', '1': 'CK','2': 'DAPI', '3': 'CD3', '4': 'FoxP3', '5': 'CD8'}

for i in range(len(LR_list[0:5])):
    lr_data = open_image(LR_list[i])
    
    # get file string
    pattern = "test\/(.*?)\.jpg"
    substring = re.search(pattern, str(LR_list[i])).group(1)
    file_name = substring+'.jpg'
    
    # get patient number:
    pattern = "test\/(.*?)_"
    patient = re.search(pattern, str(LR_list[i])).group(1)
    
    # get phenotype number:
    pattern2 = "_\-(.*?)\.jpg"
    phenotype_numb = re.search(pattern2,file_name).group(1)
    phenotype = phenotypes[phenotype_numb]
    
    # get location number:
    pattern2 = "\[(.*?)\]"
    location = re.search(pattern2,file_name).group(1)        
        
    # resample to same size as HR:
    # for pytorch, have to add a first new dimension to indicate bs = 1
    # now lr_data of shape [1, 3, 250, 334]
    lr_data = lr_data.data.unsqueeze(0)
    lr_resized = torch.nn.functional.interpolate(lr_data, size_mr[1:],
                                                     mode='bicubic')
    # remove the previous added dimension
    lr_resized = lr_resized.squeeze()

    # corresponding ground truth HR: 
    im_HR_gt = open_image(HR_list[i])
    #print('HR ground thruth shape: {}'.format(list(im_HR_gt.shape)))

    # Prediction of model: 
    p, img_pred, b = learn.predict(Image(lr_resized))
    #print('Reconstructed HR shape: {}'.format(list(p.shape)))
    # Assert reconstructed HR has same shape as ground truth HR:
    assert(list(p.shape) == list(im_HR_gt.shape))

    gt_HR_nd = trans1(trans(im_HR_gt.data)).numpy()
    pred_HR_nd = img_pred.numpy()
    
    compare_images_metrics(pred_HR_nd,gt_HR_nd, lr_resized.numpy(),'phenotype: {}, patient: {}, location: [{}]'.format(phenotype, patient, location))
MSE: 0.00054982, NMSE: 0.02757922, SSIM : 0.9386
phenotype: CD3, patient: 0131, location: [46833,16191]
MSE: 0.00062199, NMSE: 0.02816320, SSIM : 0.9436
phenotype: CK, patient: 0131, location: [42912,16664]
MSE: 0.00019524, NMSE: 0.01457474, SSIM : 0.9746
phenotype: CK, patient: 0131, location: [42242,13869]
MSE: 0.00035136, NMSE: 0.01986533, SSIM : 0.9656
phenotype: CD8, patient: 0131, location: [41114,18347]
MSE: 0.00274324, NMSE: 0.06269793, SSIM : 0.8125
phenotype: DAPI, patient: 0131, location: [45930,8132]
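The metadata extraction above hinges on three regular expressions applied to the image path. A minimal self-contained sketch of the same parsing, using a hypothetical path that follows the test-set naming scheme (`<patient>_[<location>]_part_<n>_-<phenotype>.jpg`):

```python
import re

# hypothetical example path mirroring the test-set naming convention
path = "data/test/0131_[46833,16191]_part_1_-3.jpg"

phenotypes = {'0': 'CD4', '1': 'CK', '2': 'DAPI', '3': 'CD3',
              '4': 'FoxP3', '5': 'CD8'}

# patient id: characters between "test/" and the first underscore
patient = re.search(r"test/(.*?)_", path).group(1)

# phenotype index: characters between "_-" and ".jpg"
phenotype = phenotypes[re.search(r"_-(.*?)\.jpg", path).group(1)]

# tile location: contents of the square brackets
location = re.search(r"\[(.*?)\]", path).group(1)

print(patient, phenotype, location)  # → 0131 CD3 46833,16191
```

Note that all three patterns use non-greedy `(.*?)` groups, so each one stops at the first matching delimiter rather than the last.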

Loss and error metrics:

In [152]:
# Create test table with results:

def testing_images(path_lr, path_hr):
    """Compute MSE, NMSE and SSIM for the first 20 test images:
    prediction vs ground-truth HR, and interpolated LR vs prediction."""
    phenotypes = {'0': 'CD4', '1': 'CK', '2': 'DAPI', '3': 'CD3',
                  '4': 'FoxP3', '5': 'CD8'}
    MSE, NMSE, SSIM = [], [], []
    LR_list = ImageList.from_folder(path_lr).items
    HR_list = ImageList.from_folder(path_hr).items
    file_names, locations, patients, phenotypes_list = [], [], [], []
    LR_MSE, LR_NMSE, LR_SSIM = [], [], []
    
    for i in range(len(LR_list[:20])):
        lr_data = open_image(LR_list[i])

        # get file name:
        pattern = r"test/(.*?)\.jpg"
        file_name = re.search(pattern, str(LR_list[i])).group(1) + '.jpg'
        file_names.append(file_name)

        # get patient number:
        pattern = r"test/(.*?)_"
        patients.append(re.search(pattern, str(LR_list[i])).group(1))

        # get phenotype number:
        pattern2 = r"_-(.*?)\.jpg"
        phenotype_numb = re.search(pattern2, file_name).group(1)
        phenotypes_list.append(phenotypes[phenotype_numb])

        # get tile location:
        pattern3 = r"\[(.*?)\]"
        locations.append('[{}]'.format(re.search(pattern3, file_name).group(1)))

        # resample to same size as HR:
        # for pytorch, have to add a first new dimension to indicate bs = 1
        # now lr_data of shape [1, 3, 250, 334]
        lr_data = lr_data.data.unsqueeze(0)
        lr_resized = torch.nn.functional.interpolate(lr_data, size_mr[1:],
                                                     mode='bicubic')
        # remove the previous added dimension
        lr_resized = lr_resized.squeeze()

        # corresponding ground truth HR:
        im_HR_gt = open_image(HR_list[i])
        #print('HR ground truth shape: {}'.format(list(im_HR_gt.shape)))

        # Prediction of model: 
        p, img_pred, b = learn.predict(Image(lr_resized))
        #print('Reconstructed HR shape: {}'.format(list(p.shape)))

        # Assert reconstructed HR has same shape as ground truth HR:
        assert(list(p.shape) == list(im_HR_gt.shape))

        gt_HR_nd = trans1(trans(im_HR_gt.data)).numpy()
        pred_HR_nd = img_pred.numpy()

        MSE.append(mse_mult_chann(gt_HR_nd, pred_HR_nd))
        NMSE.append(measure.compare_nrmse(gt_HR_nd, pred_HR_nd))
        SSIM.append(ssim_mult_chann(gt_HR_nd, pred_HR_nd))
        
        LR_MSE.append(mse_mult_chann(lr_resized.numpy(), pred_HR_nd))
        LR_NMSE.append(measure.compare_nrmse(lr_resized.numpy(), pred_HR_nd))
        LR_SSIM.append(ssim_mult_chann(lr_resized.numpy(), pred_HR_nd))

    return pd.DataFrame(data={'file': file_names,
                              'patient': patients,
                              'location': locations,
                              'phenotype': phenotypes_list,
                              'MSE': MSE,
                              'NMSE': NMSE,
                              'SSIM': SSIM,
                              'LR_MSE': LR_MSE,
                              'LR_NMSE': LR_NMSE,
                              'LR_SSIM': LR_SSIM})

df = testing_images(path_lr_test, path_test)
df.to_csv('data/metrics_sim6.csv')
df
Out[152]:
file patient location phenotype MSE NMSE SSIM LR_MSE LR_NMSE LR_SSIM
0 0131_[46833,16191]_part_1_-3.jpg 0131 [46833,16191] CD3 0.000550 0.027579 0.938599 0.000104 0.012013 0.990852
1 0131_[42912,16664]_part_2_-1.jpg 0131 [42912,16664] CK 0.000622 0.028163 0.943642 0.000147 0.013674 0.990004
2 0131_[42242,13869]_part_1_-1.jpg 0131 [42242,13869] CK 0.000195 0.014575 0.974555 0.000043 0.006844 0.996291
3 0131_[41114,18347]_part_2_-5.jpg 0131 [41114,18347] CD8 0.000351 0.019865 0.965638 0.000082 0.009616 0.994835
4 0131_[45930,8132]_part_4_-2.jpg 0131 [45930,8132] DAPI 0.002743 0.062698 0.812546 0.000323 0.021545 0.980212
5 0131_[49577,17995]_part_2_-5.jpg 0131 [49577,17995] CD8 0.001798 0.049299 0.822908 0.000169 0.015122 0.986398
6 0131_[40968,17689]_part_3_-0.jpg 0131 [40968,17689] CD4 0.000257 0.016644 0.975341 0.000072 0.008799 0.995703
7 0131_[47748,18478]_part_4_-5.jpg 0131 [47748,18478] CD8 0.000355 0.019816 0.964302 0.000071 0.008884 0.995677
8 0131_[43076,8951]_part_2_-3.jpg 0131 [43076,8951] CD3 0.000923 0.032773 0.873107 0.000080 0.009636 0.993585
9 0131_[50072,7868]_part_3_-5.jpg 0131 [50072,7868] CD8 0.000614 0.025865 0.918710 0.000071 0.008788 0.995579
10 0131_[44467,14410]_part_1_-4.jpg 0131 [44467,14410] FoxP3 0.000439 0.022153 0.942869 0.000055 0.007806 0.995403
11 0131_[42242,13869]_part_2_-3.jpg 0131 [42242,13869] CD3 0.000349 0.019794 0.969343 0.000084 0.009703 0.995176
12 0131_[48485,17537]_part_3_-2.jpg 0131 [48485,17537] DAPI 0.002624 0.057923 0.767526 0.000164 0.014503 0.987774
13 0131_[43515,10518]_part_4_-0.jpg 0131 [43515,10518] CD4 0.003379 0.063029 0.795015 0.000170 0.014153 0.990024
14 0131_[40237,15610]_part_1_-5.jpg 0131 [40237,15610] CD8 0.000256 0.016816 0.969992 0.000057 0.007904 0.995417
15 0131_[40734,12991]_part_4_-3.jpg 0131 [40734,12991] CD3 0.000344 0.019154 0.949077 0.000029 0.005602 0.997724
16 0131_[39373,13898]_part_3_-5.jpg 0131 [39373,13898] CD8 0.007021 0.095680 0.596867 0.000252 0.018212 0.982572
17 0131_[46833,16191]_part_1_-2.jpg 0131 [46833,16191] DAPI 0.000174 0.013666 0.974689 0.000031 0.005802 0.996930
18 0131_[50072,7868]_part_1_-1.jpg 0131 [50072,7868] CK 0.001899 0.047619 0.833912 0.000145 0.013191 0.989682
19 0131_[39813,12991]_part_2_-4.jpg 0131 [39813,12991] FoxP3 0.002537 0.054177 0.709545 0.000064 0.008633 0.992425
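The helpers `mse_mult_chann` and `ssim_mult_chann` come from `modified_model_fx` and are not shown here; as a rough sketch of what the MSE and NMSE columns measure, the following reimplements them in plain NumPy, with NRMSE using the Euclidean normalisation that skimage's `compare_nrmse` applies by default:

```python
import numpy as np

def mse(gt, pred):
    # mean squared error over all channels and pixels
    return np.mean((gt - pred) ** 2)

def nrmse(gt, pred):
    # Euclidean-normalised RMSE: RMSE divided by the root mean square of gt
    return np.sqrt(np.mean((gt - pred) ** 2)) / np.sqrt(np.mean(gt ** 2))

# toy (3, H, W) images: constant ground truth vs. a slightly darker prediction
gt = np.ones((3, 4, 4))
pred = np.full((3, 4, 4), 0.9)
print(mse(gt, pred))    # ≈ 0.01
print(nrmse(gt, pred))  # ≈ 0.1
```

Because images are scaled to [0, 1], an MSE of 0.01 corresponds to an average pixel deviation of about 0.1, which helps put the values in the table above in perspective.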
In [153]:
av_metrics = pd.DataFrame(index = ['Average and median metrics'],
    data={
        'avg MSE': np.mean(df['MSE']),
        'med MSE': np.median(df['MSE']),
        'avg NMSE': np.mean(df['NMSE']),
        'med NMSE': np.median(df['NMSE']),
        'avg SSIM': np.mean(df['SSIM']),
        'med SSIM': np.median(df['SSIM']),
        'avg LR_MSE': np.mean(df['LR_MSE']),
        'med LR_MSE': np.median(df['LR_MSE']),
        'avg LR_NMSE': np.mean(df['LR_NMSE']),
        'med LR_NMSE': np.median(df['LR_NMSE']),
        'avg LR_SSIM': np.mean(df['LR_SSIM']),
        'med LR_SSIM': np.median(df['LR_SSIM'])
    })
av_metrics.to_csv('data/av_metrics_sim6.csv')
av_metrics.transpose()
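The same per-column averages and medians can also be obtained more compactly with pandas' `agg`; a small sketch on a hypothetical stand-in for the metrics table built above:

```python
import pandas as pd

# hypothetical stand-in for the metrics DataFrame computed above
df = pd.DataFrame({'MSE': [0.001, 0.002, 0.003],
                   'SSIM': [0.95, 0.90, 0.80]})

# one row per statistic, one column per metric
summary = df.agg(['mean', 'median'])
print(summary)
```

This avoids spelling out `np.mean`/`np.median` once per column, at the cost of a slightly different table layout (statistics as rows instead of prefixed column names).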
In [154]:
# plot per-file MSE, NMSE, SSIM, with medians as horizontal lines:
# (note: the LR_* columns compare the bicubic interpolation to the
# model prediction, so the labels reflect that)
fig = plt.figure(figsize=(15, 6))

ax = fig.add_subplot(1, 3, 1)
ax.plot(df['MSE'], '-x', label='MSE prediction vs GT')
ax.plot(df['LR_MSE'], '-x', label='MSE interpolated LR vs prediction')
ax.axhline(np.median(df['MSE']), color='r')
ax.axhline(np.median(df['LR_MSE']), color='g')
ax.set_xlabel('files')
ax.set_ylabel('MSE')
ax.legend()

ax = fig.add_subplot(1, 3, 2)
ax.plot(df['NMSE'], '-x', label='NMSE prediction vs GT')
ax.plot(df['LR_NMSE'], '-x', label='NMSE interpolated LR vs prediction')
ax.axhline(np.median(df['NMSE']), color='r')
ax.axhline(np.median(df['LR_NMSE']), color='g')
ax.set_xlabel('files')
ax.set_ylabel('NMSE')
ax.legend()

ax = fig.add_subplot(1, 3, 3)
ax.plot(df['SSIM'], '-x', label='SSIM prediction vs GT')
ax.plot(df['LR_SSIM'], '-x', label='SSIM interpolated LR vs prediction')
ax.axhline(np.median(df['SSIM']), color='r')
ax.axhline(np.median(df['LR_SSIM']), color='g')
ax.set_xlabel('files')
ax.set_ylabel('SSIM')
ax.legend()

plt.suptitle('Error metrics: prediction vs GT and vs interpolated LR')
plt.savefig('images/intHRvsGT_sim6.png')
In [ ]: